{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/timsainb/tensorflow2-generative-models/blob/master/2.0-GAN-fashion-mnist.ipynb)\n",
    "\n",
    "## Generative Adversarial Network (GAN)\n",
    "GANs are a form of neural network in which two sub-networks (the generator and the discriminator) are trained on opposing loss functions: a generator that is trained to produce data which is indiscernible from the true data, and a discriminator that is trained to discern between the true data and generated data.\n",
    "\n",
    "![gan](imgs/gan.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install packages if in colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### install necessary packages if in colab\n",
    "import subprocess\n",
    "import sys\n",
    "\n",
    "\n",
    "def run_subprocess_command(cmd):\n",
    "    \"\"\"Run a whitespace-separated command and echo its output line by line.\n",
    "\n",
    "    stderr is merged into stdout so that installation errors show up in the\n",
    "    cell output, and the context manager closes the pipe and waits for the\n",
    "    child process to exit instead of leaving it unreaped.\n",
    "\n",
    "    Args:\n",
    "        cmd: command line as a single string; it is split on whitespace, so\n",
    "            arguments that contain spaces are not supported.\n",
    "    \"\"\"\n",
    "    with subprocess.Popen(\n",
    "        cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.STDOUT\n",
    "    ) as process:\n",
    "        for line in process.stdout:\n",
    "            print(line.decode().strip())\n",
    "\n",
    "\n",
    "# the google.colab module is only present when running inside Colab\n",
    "IN_COLAB = \"google.colab\" in sys.modules\n",
    "colab_requirements = [\"pip install tf-nightly-gpu-2.0-preview==2.0.0.dev20190513\"]\n",
    "if IN_COLAB:\n",
    "    for requirement in colab_requirements:\n",
    "        run_subprocess_command(requirement)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### load packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-14T05:33:37.894828Z",
     "start_time": "2019-05-14T05:33:37.890997Z"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-14T05:33:41.336960Z",
     "start_time": "2019-05-14T05:33:37.896917Z"
    }
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.autonotebook import tqdm\n",
    "%matplotlib inline\n",
    "from IPython import display\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-14T05:33:41.339755Z",
     "start_time": "2019-05-14T05:33:37.472Z"
    }
   },
   "outputs": [],
   "source": [
    "# show which TensorFlow version is in use (the Colab setup above pins a 2.0 preview build)\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a fashion-MNIST dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-14T05:33:41.340986Z",
     "start_time": "2019-05-14T05:33:37.475Z"
    }
   },
   "outputs": [],
   "source": [
    "# shuffle-buffer sizes for the train / test splits, and the batch size\n",
    "TRAIN_BUF = 60000\n",
    "TEST_BUF = 10000\n",
    "BATCH_SIZE = 512\n",
    "# shape of a single image: height, width, channels\n",
    "DIMS = (28, 28, 1)\n",
    "# number of full batches per pass; a trailing partial batch is not counted\n",
    "N_TRAIN_BATCHES = TRAIN_BUF // BATCH_SIZE\n",
    "N_TEST_BATCHES = TEST_BUF // BATCH_SIZE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-14T05:33:41.342444Z",
     "start_time": "2019-05-14T05:33:37.478Z"
    }
   },
   "outputs": [],
   "source": [
    "# load Fashion-MNIST; the labels are discarded\n",
    "(train_images, _), (test_images, _) = tf.keras.datasets.fashion_mnist.load_data()\n",
    "\n",
    "# add a trailing channel axis and rescale the pixel values by 1/255\n",
    "train_images = train_images.reshape(-1, 28, 28, 1).astype(np.float32) / 255.0\n",
    "test_images = test_images.reshape(-1, 28, 28, 1).astype(np.float32) / 255.0\n",
    "\n",
    "# shuffle each split, then group it into batches\n",
    "train_dataset = tf.data.Dataset.from_tensor_slices(train_images)\n",
    "train_dataset = train_dataset.shuffle(TRAIN_BUF).batch(BATCH_SIZE)\n",
    "test_dataset = tf.data.Dataset.from_tensor_slices(test_images)\n",
    "test_dataset = test_dataset.shuffle(TEST_BUF).batch(BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the network as a tf.keras.Model object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-14T05:33:41.343851Z",
     "start_time": "2019-05-14T05:33:37.481Z"
    }
   },
   "outputs": [],
   "source": [
    "class GAN(tf.keras.Model):\n",
    "    \"\"\" A basic GAN: a generator and a discriminator trained adversarially.\n",
    "\n",
    "    Extends:\n",
    "        tf.keras.Model\n",
    "\n",
    "    Keyword Args:\n",
    "        gen: list of Keras layers, wrapped into the generator `Sequential`.\n",
    "        disc: list of Keras layers, wrapped into the discriminator `Sequential`.\n",
    "        gen_optimizer: optimizer applied to the generator variables.\n",
    "        disc_optimizer: optimizer applied to the discriminator variables.\n",
    "        n_Z: size of the last axis of the noise fed to the generator.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        super(GAN, self).__init__()\n",
    "        # expose every keyword argument as an attribute of the model\n",
    "        self.__dict__.update(kwargs)\n",
    "\n",
    "        # the layer lists passed in as `gen` / `disc` become Sequential models\n",
    "        self.gen = tf.keras.Sequential(self.gen)\n",
    "        self.disc = tf.keras.Sequential(self.disc)\n",
    "\n",
    "    def generate(self, z):\n",
    "        \"\"\" runs noise `z` through the generator\n",
    "        \"\"\"\n",
    "        return self.gen(z)\n",
    "\n",
    "    def discriminate(self, x):\n",
    "        \"\"\" runs `x` through the discriminator\n",
    "        \"\"\"\n",
    "        return self.disc(x)\n",
    "\n",
    "    def compute_loss(self, x):\n",
    "        \"\"\" passes through the network and computes loss\n",
    "\n",
    "        Args:\n",
    "            x: batch of real samples.\n",
    "\n",
    "        Returns:\n",
    "            tuple (disc_loss, gen_loss)\n",
    "        \"\"\"\n",
    "        # generating noise from a normal distribution; the batch size is read\n",
    "        # dynamically so this also works inside a tf.function when the static\n",
    "        # batch dimension of x is unknown\n",
    "        batch_size = tf.shape(x)[0]\n",
    "        z_samp = tf.random.normal([batch_size, 1, 1, self.n_Z])\n",
    "\n",
    "        # run noise through generator\n",
    "        x_gen = self.generate(z_samp)\n",
    "        # discriminate x and x_gen\n",
    "        logits_x = self.discriminate(x)\n",
    "        logits_x_gen = self.discriminate(x_gen)\n",
    "        ### losses\n",
    "        # losses of real with label \"1\"\n",
    "        disc_real_loss = gan_loss(logits=logits_x, is_real=True)\n",
    "        # losses of fake with label \"0\"\n",
    "        disc_fake_loss = gan_loss(logits=logits_x_gen, is_real=False)\n",
    "        disc_loss = disc_fake_loss + disc_real_loss\n",
    "\n",
    "        # losses of fake with label \"1\"\n",
    "        gen_loss = gan_loss(logits=logits_x_gen, is_real=True)\n",
    "\n",
    "        return disc_loss, gen_loss\n",
    "\n",
    "    def compute_gradients(self, x):\n",
    "        \"\"\" computes the losses on `x` and their gradients\n",
    "\n",
    "        Returns:\n",
    "            tuple (gen_gradients, disc_gradients): gradient of the generator\n",
    "            loss w.r.t. the generator variables, and of the discriminator\n",
    "            loss w.r.t. the discriminator variables.\n",
    "        \"\"\"\n",
    "        ### pass through network\n",
    "        # one tape per sub-network, since each tape is queried once below\n",
    "        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n",
    "            disc_loss, gen_loss = self.compute_loss(x)\n",
    "\n",
    "        # compute gradients\n",
    "        gen_gradients = gen_tape.gradient(gen_loss, self.gen.trainable_variables)\n",
    "        disc_gradients = disc_tape.gradient(disc_loss, self.disc.trainable_variables)\n",
    "\n",
    "        return gen_gradients, disc_gradients\n",
    "\n",
    "    def apply_gradients(self, gen_gradients, disc_gradients):\n",
    "        \"\"\" applies each set of gradients with the matching optimizer\n",
    "        \"\"\"\n",
    "        self.gen_optimizer.apply_gradients(\n",
    "            zip(gen_gradients, self.gen.trainable_variables)\n",
    "        )\n",
    "        self.disc_optimizer.apply_gradients(\n",
    "            zip(disc_gradients, self.disc.trainable_variables)\n",
    "        )\n",
    "\n",
    "    @tf.function\n",
    "    def train(self, train_x):\n",
    "        \"\"\" one training step on a batch: compute gradients, then apply them\n",
    "        \"\"\"\n",
    "        gen_gradients, disc_gradients = self.compute_gradients(train_x)\n",
    "        self.apply_gradients(gen_gradients, disc_gradients)\n",
    "        \n",
    "        \n",
    "def gan_loss(logits, is_real=True):\n",
    "    \"\"\"Sigmoid cross-entropy of `logits` against an all-ones or all-zeros target.\n",
    "\n",
    "    Args:\n",
    "        logits: unnormalized discriminator outputs.\n",
    "        is_real: if True the target is all ones, otherwise all zeros.\n",
    "    \"\"\"\n",
    "    make_labels = tf.ones_like if is_real else tf.zeros_like\n",
    "    return tf.compat.v1.losses.sigmoid_cross_entropy(\n",
    "        multi_class_labels=make_labels(logits), logits=logits\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the network architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-14T05:33:41.345202Z",
     "start_time": "2019-05-14T05:33:37.484Z"
    }
   },
   "outputs": [],
   "source": [
    "# size of the last axis of the noise fed to the generator\n",
    "N_Z = 64\n",
    "# generator: noise -> 7x7x64 feature map -> upsampled to a 28x28x1 image\n",
    "generator = [\n",
    "    tf.keras.layers.Dense(units=7 * 7 * 64, activation=\"relu\"),\n",
    "    tf.keras.layers.Reshape(target_shape=(7, 7, 64)),\n",
    "    # stride 2 with SAME padding doubles the spatial size: 7x7 -> 14x14\n",
    "    tf.keras.layers.Conv2DTranspose(\n",
    "        filters=64, kernel_size=3, strides=(2, 2), padding=\"SAME\", activation=\"relu\"\n",
    "    ),\n",
    "    # 14x14 -> 28x28\n",
    "    tf.keras.layers.Conv2DTranspose(\n",
    "        filters=32, kernel_size=3, strides=(2, 2), padding=\"SAME\", activation=\"relu\"\n",
    "    ),\n",
    "    # stride 1 keeps 28x28; one output channel, sigmoid keeps values in (0, 1)\n",
    "    tf.keras.layers.Conv2DTranspose(\n",
    "        filters=1, kernel_size=3, strides=(1, 1), padding=\"SAME\", activation=\"sigmoid\"\n",
    "    ),\n",
    "]\n",
    "\n",
    "# discriminator: image of shape DIMS -> a single unnormalized score (logit)\n",
    "discriminator = [\n",
    "    tf.keras.layers.InputLayer(input_shape=DIMS),\n",
    "    # no padding argument is given, so the layer default applies\n",
    "    tf.keras.layers.Conv2D(\n",
    "        filters=32, kernel_size=3, strides=(2, 2), activation=\"relu\"\n",
    "    ),\n",
    "    tf.keras.layers.Conv2D(\n",
    "        filters=64, kernel_size=3, strides=(2, 2), activation=\"relu\"\n",
    "    ),\n",
    "    tf.keras.layers.Flatten(),\n",
    "    # no activation here: gan_loss applies a sigmoid cross-entropy to the logits\n",
    "    tf.keras.layers.Dense(units=1, activation=None),\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-10T18:40:40.306731Z",
     "start_time": "2019-05-10T18:40:40.292930Z"
    }
   },
   "source": [
    "### Create Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-14T05:33:41.346464Z",
     "start_time": "2019-05-14T05:33:37.487Z"
    }
   },
   "outputs": [],
   "source": [
    "# one optimizer per sub-network\n",
    "gen_optimizer = tf.keras.optimizers.Adam(0.001, beta_1=0.5)\n",
    "disc_optimizer = tf.keras.optimizers.RMSprop(0.005)\n",
    "\n",
    "# assemble the GAN from the layer lists, the optimizers and the noise size\n",
    "model = GAN(\n",
    "    n_Z=N_Z,\n",
    "    gen=generator,\n",
    "    disc=discriminator,\n",
    "    gen_optimizer=gen_optimizer,\n",
    "    disc_optimizer=disc_optimizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-14T05:33:41.347744Z",
     "start_time": "2019-05-14T05:33:37.489Z"
    }
   },
   "outputs": [],
   "source": [
    "# plot generated examples to monitor training\n",
    "def plot_reconstruction(model, nex=8, zm=2):\n",
    "    \"\"\"Plot `nex` samples drawn from the generator in a single row.\n",
    "\n",
    "    Args:\n",
    "        model: GAN exposing `generate` and the noise size `n_Z`.\n",
    "        nex: number of generated examples to show.\n",
    "        zm: size of each panel, passed on to `figsize`.\n",
    "    \"\"\"\n",
    "    # only generate the samples that are shown, and convert to numpy once\n",
    "    noise = tf.random.normal(shape=(nex, model.n_Z))\n",
    "    samples = model.generate(noise).numpy()\n",
    "    # squeeze=False keeps `axs` two-dimensional, so nex=1 works as well\n",
    "    fig, axs = plt.subplots(\n",
    "        ncols=nex, nrows=1, figsize=(zm * nex, zm), squeeze=False\n",
    "    )\n",
    "    for ax, sample in zip(axs[0], samples):\n",
    "        ax.matshow(sample.squeeze(), cmap=plt.cm.Greys, vmin=0, vmax=1)\n",
    "        ax.axis('off')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-14T05:33:41.349152Z",
     "start_time": "2019-05-14T05:33:37.492Z"
    }
   },
   "outputs": [],
   "source": [
    "# a pandas dataframe to save the loss information to: the training loop\n",
    "# appends one row (mean disc_loss, mean gen_loss) per epoch\n",
    "losses = pd.DataFrame(columns = ['disc_loss', 'gen_loss'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-14T05:33:41.350437Z",
     "start_time": "2019-05-14T05:33:37.496Z"
    }
   },
   "outputs": [],
   "source": [
    "n_epochs = 50\n",
    "for epoch in range(n_epochs):\n",
    "    # train: zipping with range() stops after N_TRAIN_BATCHES full batches\n",
    "    for batch, train_x in tqdm(\n",
    "        zip(range(N_TRAIN_BATCHES), train_dataset), total=N_TRAIN_BATCHES\n",
    "    ):\n",
    "        model.train(train_x)\n",
    "    # test on holdout: the losses are evaluated on the held-out batch test_x,\n",
    "    # not on the last training batch\n",
    "    loss = []\n",
    "    for batch, test_x in tqdm(\n",
    "        zip(range(N_TEST_BATCHES), test_dataset), total=N_TEST_BATCHES\n",
    "    ):\n",
    "        loss.append(model.compute_loss(test_x))\n",
    "    # average (disc_loss, gen_loss) over the test batches; one row per epoch\n",
    "    losses.loc[len(losses)] = np.mean(loss, axis=0)\n",
    "    # plot results\n",
    "    display.clear_output()\n",
    "    print(\n",
    "        \"Epoch: {} | disc_loss: {} | gen_loss: {}\".format(\n",
    "            epoch, losses.disc_loss.values[-1], losses.gen_loss.values[-1]\n",
    "        )\n",
    "    )\n",
    "    plot_reconstruction(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
